{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.pylab as pylab\n",
    "import matplotlib.cm as cm\n",
    "%matplotlib inline\n",
    "import scipy.misc\n",
    "from PIL import Image\n",
    "import scipy.io\n",
    "import os\n",
    "import cv2\n",
    "import time\n",
    "\n",
    "# Make sure that caffe is on the python path:\n",
    "caffe_root = '../../'  # this file is expected to be in {caffe_root}/examples/hed/\n",
    "import sys\n",
    "sys.path.insert(0, caffe_root + 'python')\n",
    "\n",
    "import caffe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Root of the HED-BSDS data. test.lst holds one image path per line;\n",
    "# the paths are later joined onto data_root when the images are opened.\n",
    "data_root = '../../data/HED-BSDS/'\n",
    "with open(data_root+'test.lst') as f:\n",
    "    # Strip the line endings and skip blank lines, so that e.g. a trailing\n",
    "    # empty line does not become an empty image path further down.\n",
    "    test_lst = [line.strip() for line in f if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Per-channel mean subtracted from every image. The values are listed in\n",
    "# the channel order that results from the RGB -> BGR flip below.\n",
    "# Built once instead of on every loop iteration.\n",
    "mean_bgr = np.array((104.00698793,116.66876762,122.67891434))\n",
    "\n",
    "im_lst = []\n",
    "for rel_path in test_lst:\n",
    "    # The context manager releases the image file as soon as the pixels\n",
    "    # have been copied into the array.\n",
    "    with Image.open(data_root+rel_path) as im:\n",
    "        # convert('RGB') guarantees exactly three channels, so grayscale,\n",
    "        # palette or RGBA files no longer break the steps below.\n",
    "        in_ = np.array(im.convert('RGB'), dtype=np.float32)\n",
    "    in_ = in_[:,:,::-1]  # reverse the channel axis: RGB -> BGR\n",
    "    in_ -= mean_bgr\n",
    "    im_lst.append(in_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Visualization\n",
    "def plot_single_scale(scale_lst, size):\n",
    "    \"\"\"Show the given maps side by side in one row of a new figure.\n",
    "\n",
    "    Each entry is displayed inverted (1 - entry) with the reversed grey\n",
    "    colormap, and tick labels and tick marks are hidden.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    scale_lst : sequence of arrays\n",
    "        Images to display, one per panel; anything plt.imshow accepts.\n",
    "        Presumably values lie in [0, 1] -- TODO confirm against callers.\n",
    "    size : number\n",
    "        Figure width; the figure height is set to size/2.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    The figure size is set through the global pylab.rcParams, so figures\n",
    "    created afterwards use the same size.\n",
    "    \"\"\"\n",
    "    pylab.rcParams['figure.figsize'] = size, size/2\n",
    "\n",
    "    # Keep the five-panel layout for up to five maps, but widen the grid\n",
    "    # when more are passed instead of failing inside plt.subplot.\n",
    "    n_cols = max(5, len(scale_lst))\n",
    "\n",
    "    plt.figure()\n",
    "    for i, scale in enumerate(scale_lst):\n",
    "        s = plt.subplot(1, n_cols, i+1)\n",
    "        plt.imshow(1-scale, cmap = cm.Greys_r)\n",
    "        # Hide tick labels and tick marks on both axes.\n",
    "        s.set_xticklabels([])\n",
    "        s.set_yticklabels([])\n",
    "        s.yaxis.set_ticks_position('none')\n",
    "        s.xaxis.set_ticks_position('none')\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Run on GPU 0; delete the next two calls when testing on the CPU.\n",
    "caffe.set_mode_gpu()\n",
    "caffe.set_device(0)\n",
    "\n",
    "# Instantiate the network in TEST phase from the definition file and the\n",
    "# pretrained weights.\n",
    "model_def = 'test.prototxt'\n",
    "model_weights = 'rcf_pretrained_bsds.caffemodel'\n",
    "net = caffe.Net(model_def, model_weights, caffe.TEST)\n",
    "\n",
    "# Directory the result images are written to; created on first run.\n",
    "save_root = os.path.join(data_root, 'test-fcn')\n",
    "if not os.path.exists(save_root):\n",
    "    os.mkdir(save_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Detection took 0.036s per image\n"
     ]
    }
   ],
   "source": [
    "# Note: the timed region covers the forward pass and writing each PNG.\n",
    "start_time = time.time()\n",
    "for rel_path, image in zip(test_lst, im_lst):\n",
    "    # move the channel axis to the front: H x W x C -> C x H x W\n",
    "    in_ = image.transpose((2, 0, 1))\n",
    "\n",
    "    # shape for input (data blob is N x C x H x W), set data\n",
    "    net.blobs['data'].reshape(1, *in_.shape)\n",
    "    net.blobs['data'].data[...] = in_\n",
    "    # run the forward pass\n",
    "    net.forward()\n",
    "\n",
    "    # Take the fused output for the single batch item (first channel),\n",
    "    # invert it and scale it to the 0..255 range.\n",
    "    # The blobs 'sigmoid-dsn1' .. 'sigmoid-dsn5' (presumably defined in\n",
    "    # test.prototxt) can be saved the same way with the suffixes\n",
    "    # '_out1' .. '_out5' if the per-stage outputs are needed.\n",
    "    fuse = net.blobs['sigmoid-fuse'].data[0][0, :, :]\n",
    "    fuse = 255 * (1-fuse)\n",
    "\n",
    "    # Name the result after the image file itself instead of slicing a\n",
    "    # fixed number of characters off the list entry.\n",
    "    stem = os.path.splitext(os.path.basename(rel_path))[0]\n",
    "    cv2.imwrite(os.path.join(save_root, stem + '_fuse.png'), fuse)\n",
    "\n",
    "diff_time = time.time() - start_time\n",
    "# Avoid dividing by zero on an empty test list. print() with a single\n",
    "# argument behaves the same on Python 2 and Python 3.\n",
    "if test_lst:\n",
    "    print('Detection took {:.3f}s per image'.format(diff_time/len(test_lst)))\n",
    "else:\n",
    "    print('No test images found in test.lst')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
